{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "73bd968b-d970-4a05-94ef-4e7abf990827",
   "metadata": {},
   "source": [
    "Chapter 23\n",
    "\n",
    "# 求解线性方程\n",
    "Book_3《数学要素》 | 鸢尾花书：从加减乘除到机器学习 (第二版)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d17de955-9c02-4f23-bec2-48dfdbddcb7f",
   "metadata": {},
   "source": [
    "这段代码通过不同的方法求解了一个线性方程组$Ax = b$，其中矩阵$A = \\begin{bmatrix} 1 & 1 \\\\ 2 & 4 \\end{bmatrix}$，向量$b = \\begin{bmatrix} 35 \\\\ 94 \\end{bmatrix}$。具体来说，代码依次使用矩阵求逆、直接解法和符号求解三种方式来得到解$x$。\n",
    "\n",
    "1. **矩阵求逆法**：首先计算矩阵$A$的逆矩阵$A^{-1}$，然后通过矩阵乘法$x = A^{-1}b$得到解$x$。这是基于线性方程组的一般解法$x = A^{-1}b$。\n",
    "   \n",
    "2. **NumPy的直接解法**：使用NumPy中的`solve`函数求解该线性方程组。该函数内部优化了求解过程，通常比求逆计算更为高效。\n",
    "\n",
    "3. **符号求解**：接着引入SymPy库，通过符号变量$x_1$和$x_2$构建方程$x_1 + x_2 = 35$和$2x_1 + 4x_2 = 94$。SymPy的`solve`函数求得每个变量的解，并以字典形式输出。\n",
    "\n",
    "4. **线性方程组求解集**：最后，使用SymPy的`linsolve`函数求解方程组。该函数返回一个包含解的集合形式，适用于线性方程组的求解。\n",
    "\n",
    "综上，代码验证了方程组$x_1 + x_2 = 35$和$2x_1 + 4x_2 = 94$的解，通过数值方法和符号方法都得到相同的解$x = \\begin{bmatrix} 23 \\\\ 12 \\end{bmatrix}$。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c878bdb8-7303-4c06-b240-364b23d218c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43132808-16e6-407c-be98-80e247ccae75",
   "metadata": {},
   "source": [
    "## 定义矩阵A和向量b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c7c1f3d4-8cb0-4af5-a4a4-1c004a00dd48",
   "metadata": {},
   "outputs": [],
   "source": [
    "A = np.array([[1,1], # 系数矩阵\n",
    "              [2,4]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9db9cdf1-e747-4e5e-be24-50147d6b9e8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "b = np.array([[35],  # 常数项向量\n",
    "              [94]])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b74f5ab-eae6-4eff-8814-b5c34dd9c18e",
   "metadata": {},
   "source": [
    "## 通过矩阵的逆求解方程组"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e1f20913-4913-4363-9a58-5173786cb371",
   "metadata": {},
   "outputs": [],
   "source": [
    "A_inv = np.linalg.inv(A) # 计算矩阵A的逆"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4a633470-ee4f-49b5-9827-3afa802b087e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[23.]\n",
      " [12.]]\n"
     ]
    }
   ],
   "source": [
    "x = A_inv@b # 计算解向量x\n",
    "print(x) # 输出解向量"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecda8f48-03bb-41ac-88bb-b45cfeaadc9f",
   "metadata": {},
   "source": [
    "## 使用NumPy中的solve函数求解方程组"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8b6520bf-249b-4aa0-a547-b8f1220a9c53",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[23.]\n",
      " [12.]]\n"
     ]
    }
   ],
   "source": [
    "x_ = np.linalg.solve(A,b) # 使用solve直接求解\n",
    "print(x_) # 输出解向量"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53e0da88-aecd-4508-9d06-a942f9ad6db2",
   "metadata": {},
   "source": [
    "## 使用SymPy符号计算解"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "da4fdff4-02d1-47d7-9e9a-7c4305f3c7d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sympy import *\n",
    "x1, x2 = symbols(['x1', 'x2']) # 定义符号变量x1, x2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cedde266-0648-4d6d-a3c7-4417fd37aa14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{x1: 23, x2: 12}\n"
     ]
    }
   ],
   "source": [
    "sol = solve([x1 + x2 - 35, 2*x1 + 4*x2 - 94], [x1, x2]) # 使用solve求解方程组\n",
    "print(sol) # 输出解的字典形式"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3077383a-a387-4f61-809a-3d68ba7e3d2a",
   "metadata": {},
   "source": [
    "## 使用linsolve求解线性方程组"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f1c6d415-70f7-4e1d-8385-bcd72ef176f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{(23, 12)}\n"
     ]
    }
   ],
   "source": [
    "from sympy.solvers.solveset import linsolve\n",
    "sol_ = linsolve([x1 + x2 - 35, 2*x1 + 4*x2 - 94], [x1, x2]) # 使用linsolve求解\n",
    "print(sol_) # 输出解的集合形式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85a80909-2aac-49ed-bb7a-f8cc6b80ee7d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecd322f4-f919-4be2-adc3-69d28ef25e69",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
